{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining output variables to reduce memory usage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install \"pybamm[plot,cite]\" -q    # install PyBaMM if it is not installed\n",
    "import time\n",
    "import tracemalloc\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import pybamm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Solution storage and output variables\n",
    "\n",
    "A PyBaMM model is defined as a set of differential equations. For an ordinary differential equation (ODE) model, the model is defined as\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "\\frac{d\\mathbf{y}}{dt} = \\mathbf{f}(\\mathbf{y}, t)\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "where $\\mathbf{y}$ is the state vector, and $\\mathbf{f}$ is the vector of ODEs (PyBaMM actually uses a semi-explicit DAE equation form, but for the purposes of this notebook we can just assume we're dealing with an ODE). The state vector $\\mathbf{y}$ contains all the variables that are integrated over time by the solver.\n",
    "\n",
    "For example, we can create a DFN model with PyBaMM and look at each of the state variables individually and all together in the final concatenated state vector $\\mathbf{y}$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The model state variables are: ['Discharge capacity [A.h]', 'Throughput capacity [A.h]', 'Negative particle concentration [mol.m-3]', 'Positive particle concentration [mol.m-3]', 'Porosity times concentration [mol.m-3]']\n",
      "The concatenated state vector is a vector of shape (862, 1)\n"
     ]
    }
   ],
   "source": [
    "model = pybamm.lithium_ion.DFN()\n",
    "print(\"The model state variables are:\", [var.name for var in model.rhs.keys()])\n",
    "sim = pybamm.Simulation(model)\n",
    "sim.build()\n",
    "print(\n",
    "    \"The concatenated state vector is a vector of shape\",\n",
    "    sim.built_model.concatenated_rhs.shape,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The user is also interested in a number of output variables $\\mathbf{g}$. Once the simulation is complete and the model has been solved to get $\\mathbf{y}$, the output variables $\\mathbf{g}$ can be calculated from $\\mathbf{y}$ using a function $\\mathbf{h}$:\n",
    "\n",
    "$$\n",
    "\\begin{align*}\n",
    "\\mathbf{g} = \\mathbf{h}(\\mathbf{y})\n",
    "\\end{align*}\n",
    "$$\n",
    "\n",
    "where $\\mathbf{h}$ is a function that extracts the variables of interest from the state vector. Each model in PyBaMM has many such variables of interest, and the user can choose which ones to extract and plot. E.g. perhaps the user wants to solve the DFN model and plot the positive electrode capacity? In this case, the expression for the variable of interest is a function $h$ of the state vector $\\mathbf{y}$, and this function is defined in the DFN model using a PyBaMM expression tree:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0002777777777777778 * yz-average(x-average(Positive electrode active material volume fraction)) * Positive electrode thickness [m] * Electrode width [m] * Electrode height [m] * Maximum concentration in positive electrode [mol.m-3] * Faraday constant [C.mol-1]\n"
     ]
    }
   ],
   "source": [
    "print(model.variables[\"Positive electrode capacity [A.h]\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyBaMM doesn't evaluate this expression until the user asks for it when they extract it from the solution object. Instead, it stores the solution vector $\\mathbf{y}$ at each time point, and then evaluates the expression for the variable of interest $\\mathbf{g}$ at each time point when the user asks for it. In the code below, the evaluation of the positive electrode capacity is done by calling `solution[\"Positive electrode capacity [A.h]\"]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve the model, storing the state vector at each time step\n",
    "solution = sim.solve([0, 3600])\n",
    "\n",
    "# Extract the positive electrode capacity using the function $h(y)$ and the stored state vector $y$\n",
    "pos_elec_capacity = solution[\"Positive electrode capacity [A.h]\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "During a solve, the PyBaMM solver needs to store the solution vector $\\mathbf{y}$ at each time step. By default, PyBaMM stores the entire vector, so if the state vector is of size $n$ and the solver steps through $m$ time points, the memory usage is $O(nm)$. This can be a problem for large models, especially when running on a machine with limited memory, or when running a very long-running simulation that needs to store many time points.\n",
    "\n",
    "## Calculating output variables on-the-fly\n",
    "\n",
    "However, it is likely that the user is only interested in a small subset of the variables, and that the size of these output variables $\\mathbf{g}$ is much smaller than the size of the state vector $\\mathbf{y}$. Therefore, it could be wasteful to store the entire state vector $\\mathbf{y}$.\n",
    "To address this, PyBaMM has a feature called \"output variables\". This allows the user to specify a list of variables that they are interested in up front (i.e. before the solve). During the solve, PyBaMM will evaluate the output variables of interest at each time point and discard the rest of the state vector. At the end of the solve, the `pybamm.Solution` object will only contain the output variables of interest, as well as the full state vector at the last time point (so that other simulations can be run from this point).\n",
    "\n",
    "Let's see how this works in practice. We'll start by solving a DFN model with a long-running experimental protocol, and assume we are only interested in the terminal voltage. We setup a PyBaMM simulation as normal and solve it, and once the solver has finished we extract the terminal voltage from the solution object. To evaluate the memory usage of this simulation, we will use the `tracemalloc` library, which allows us to track the total amount of memory allocated by the Python interpreter. We'll also keep track of the time taken to solve the model using the `time` library. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time to solve:  10.803891235031188\n",
      "Total allocated size: 202.7 MB\n"
     ]
    }
   ],
   "source": [
    "model = pybamm.lithium_ion.DFN()\n",
    "solver = pybamm.IDAKLUSolver()\n",
    "experiment = pybamm.Experiment(\n",
    "    [\n",
    "        \"Discharge at 0.1 A for 1 hour\",\n",
    "        \"Charge at 0.1 A for 1 hour\",\n",
    "    ]\n",
    "    * 100\n",
    ")\n",
    "\n",
    "tracemalloc.start()\n",
    "sim = pybamm.Simulation(model, solver=solver, experiment=experiment)\n",
    "time_start = time.perf_counter()\n",
    "sol = sim.solve()\n",
    "t_eval = np.linspace(0, 3600 * 10, 100)\n",
    "voltage = sol[\"Terminal voltage [V]\"](t_eval)\n",
    "time_end = time.perf_counter()\n",
    "print(\"Time to solve: \", time_end - time_start)\n",
    "\n",
    "snapshot = tracemalloc.take_snapshot()\n",
    "top_stats = snapshot.statistics(\"lineno\")\n",
    "total_size = sum(stat.size for stat in top_stats)\n",
    "print(\"Total allocated size: %.1f MB\" % (total_size / 10**6))\n",
    "tracemalloc.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So for about 10 seconds of compute we have already generated 202 MB of data that is in memory, for longer-running simulations this can quickly become a problem. We will now use the output variables feature to reduce the memory usage. When the solver is created, we pass in a list of output variables that we are interested in. The solver will then only store these variables at each time point, rather than the entire state vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time to solve:  10.911966408137232\n",
      "Total allocated size: 13.3 MB\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import tracemalloc\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import pybamm\n",
    "\n",
    "model = pybamm.lithium_ion.DFN()\n",
    "solver = pybamm.IDAKLUSolver(output_variables=[\"Voltage [V]\"])\n",
    "experiment = pybamm.Experiment(\n",
    "    [\n",
    "        \"Discharge at 0.1 A for 1 hour\",\n",
    "        \"Charge at 0.1 A for 1 hour\",\n",
    "    ]\n",
    "    * 100\n",
    ")\n",
    "\n",
    "tracemalloc.start()\n",
    "sim = pybamm.Simulation(model, solver=solver, experiment=experiment)\n",
    "time_start = time.perf_counter()\n",
    "sol = sim.solve()\n",
    "t_eval = np.linspace(0, 3600 * 10, 100)\n",
    "voltage = sol[\"Voltage [V]\"](t_eval)\n",
    "time_end = time.perf_counter()\n",
    "print(\"Time to solve: \", time_end - time_start)\n",
    "\n",
    "snapshot = tracemalloc.take_snapshot()\n",
    "top_stats = snapshot.statistics(\"lineno\")\n",
    "total_size = sum(stat.size for stat in top_stats)\n",
    "print(\"Total allocated size: %.1f MB\" % (total_size / 10**6))\n",
    "tracemalloc.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That is, the solve time is the same, but the memory usage has been reduced by an order of magnitude. \n",
    "\n",
    "## Limitations\n",
    "\n",
    "An obvious downside is that we can only plot the variables that we have stored, if any other variables are accessed the solution object will raise an error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\"Cannot process variable 'X-averaged positive particle surface concentration [mol.m-3]' as it was not part of the solve. Please re-run the solve with `output_variables` set to include this variable.\"\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    sol[\"X-averaged positive particle surface concentration [mol.m-3]\"]\n",
    "except KeyError as e:\n",
    "    print(\"Error:\", e)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
